// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include "clspv/Option.h"
#include "spirv/unified1/spirv.hpp"

#include "Constants.h"
#include "Passes.h"

using namespace llvm;

#define DEBUG_TYPE "ReplaceLLVMIntrinsics"

namespace {
// Module pass that rewrites calls to a handful of LLVM intrinsics
// (llvm.memset, llvm.memcpy, llvm.lifetime.*, llvm.fshl, llvm.ctlz,
// llvm.cttz, llvm.copysign) into code built from simpler instructions or
// clspv-specific helper calls.
struct ReplaceLLVMIntrinsicsPass final : public ModulePass {
  static char ID;
  ReplaceLLVMIntrinsicsPass() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;

  // Module-wide rewrites.
  // TODO: update module-based functions to work like function-based ones.
  // Except maybe lifetime intrinsics.
  bool removeLifetimeDeclarations(Module &M);
  bool replaceMemset(Module &M);
  bool replaceMemcpy(Module &M);

  // Per-declaration rewrites, dispatched on the intrinsic ID.
  bool runOnFunction(Function &F);
  bool replaceFshl(Function &F);
  bool replaceCountZeroes(Function &F, bool leading);
  bool replaceCopysign(Function &F);

  // Shared helper: replaces each call to F with the value returned by
  // Replacer for that call.
  bool replaceCallsWithValue(Function &F,
                             std::function<Value *(CallInst *)> Replacer);

  // Intrinsic declarations collected for deletion at the end of
  // runOnModule.
  SmallVector<Function *, 16> DeadFunctions;
};
} // namespace

// Pass identifier handed to ModulePass(ID) by the constructor.
char ReplaceLLVMIntrinsicsPass::ID = 0;
// Registers the pass under the name "ReplaceLLVMIntrinsics". The two trailing
// `false` arguments are INITIALIZE_PASS's CFG-only and is-analysis flags.
INITIALIZE_PASS(ReplaceLLVMIntrinsicsPass, "ReplaceLLVMIntrinsics",
                "Replace LLVM intrinsics Pass", false, false)

namespace clspv {
// Creates a fresh instance of the pass. The returned pointer is owned by the
// caller (presumably a pass manager — verify at the call site).
ModulePass *createReplaceLLVMIntrinsicsPass() {
  auto *pass = new ReplaceLLVMIntrinsicsPass();
  return pass;
}
} // namespace clspv

// Runs every intrinsic rewrite over the module and then deletes the intrinsic
// declarations that were fully replaced.
// Returns true if the module was modified.
bool ReplaceLLVMIntrinsicsPass::runOnModule(Module &M) {
  bool Changed = false;

  // Remove lifetime annotations first.  They could be using memset
  // and memcpy calls.
  Changed |= removeLifetimeDeclarations(M);
  Changed |= replaceMemset(M);
  Changed |= replaceMemcpy(M);

  for (auto &F : M) {
    Changed |= runOnFunction(F);
  }

  // Only erase declarations that really are dead: erasing a function that
  // still has uses (e.g. a call the replacer declined to lower) asserts.
  for (auto *F : DeadFunctions) {
    if (F->use_empty()) {
      F->eraseFromParent();
    }
  }
  // The pointers above may now dangle; never carry them into another run.
  DeadFunctions.clear();

  return Changed;
}

// Dispatches F to the matching rewrite when it declares one of the handled
// intrinsics (fshl, copysign, ctlz, cttz). Any other function is ignored.
// Returns true if calls were rewritten.
bool ReplaceLLVMIntrinsicsPass::runOnFunction(Function &F) {
  const auto IID = F.getIntrinsicID();

  if (IID == Intrinsic::fshl)
    return replaceFshl(F);
  if (IID == Intrinsic::copysign)
    return replaceCopysign(F);
  if (IID == Intrinsic::ctlz || IID == Intrinsic::cttz)
    return replaceCountZeroes(F, /*leading=*/IID == Intrinsic::ctlz);

  return false;
}

// Replaces direct calls to F with the value produced by Replacer.
//
// Replacer is invoked once per call that uses F as its callee. It may insert
// new instructions before the call. It returns the replacement value, or
// nullptr (or the call itself) to leave that call in place.
//
// F is queued in DeadFunctions only when no uses remain afterwards. A
// declaration whose calls were not all lowered therefore survives, rather
// than being erased while still referenced.
//
// Returns true if at least one call was replaced.
bool ReplaceLLVMIntrinsicsPass::replaceCallsWithValue(
    Function &F, std::function<Value *(CallInst *)> Replacer) {
  SmallVector<Instruction *, 8> ToRemove;
  for (auto &U : F.uses()) {
    auto Call = dyn_cast<CallInst>(U.getUser());
    // Only rewrite uses where F is the callee. This guarantees each call is
    // visited (and later erased) exactly once.
    if (Call == nullptr || !Call->isCallee(&U)) {
      continue;
    }

    auto replacement = Replacer(Call);
    if (replacement != nullptr && replacement != Call) {
      Call->replaceAllUsesWith(replacement);
      ToRemove.push_back(Call);
    }
  }

  // Erase after the walk so F's use list is not mutated while iterating it.
  for (auto inst : ToRemove) {
    inst->eraseFromParent();
  }

  if (F.use_empty()) {
    DeadFunctions.push_back(&F);
  }

  return !ToRemove.empty();
}

// Lowers llvm.fshl (funnel shift left) to shifts and an OR:
//   fshl(hi, lo, s) = (hi << (s % w)) | (lo >> (w - s % w))
// where w is the element bit width. When s % w == 0 the result is exactly hi.
// Calls on element widths other than 8/16/32/64 are left unchanged.
// Returns true if any call was rewritten.
bool ReplaceLLVMIntrinsicsPass::replaceFshl(Function &F) {
  return replaceCallsWithValue(F, [](CallInst *call) {
    auto arg_hi = call->getArgOperand(0);
    auto arg_lo = call->getArgOperand(1);
    auto arg_shift = call->getArgOperand(2);

    // Validate argument types.
    auto type = arg_hi->getType();
    const auto width = type->getScalarSizeInBits();
    if (width != 8 && width != 16 && width != 32 && width != 64) {
      return static_cast<Value *>(nullptr);
    }

    IRBuilder<> builder(call);

    // The shift amount is treated modulo the element size.
    auto mod_mask = ConstantInt::get(type, width - 1);
    auto shift_amount = builder.CreateAnd(arg_shift, mod_mask);

    // Shift the bottom bits of the first argument up.
    auto hi_bits = builder.CreateShl(arg_hi, shift_amount);

    // Shift the top bits of the second argument down by (width - shift).
    // Doing that in a single shift would shift by the full width when the
    // shift amount is 0. That is poison in LLVM and undefined in SPIR-V.
    // Instead shift by 1 and then by (width - 1 - shift). Both amounts are
    // always in range, and the second equals (shift ^ (width - 1)) because
    // shift is already masked to [0, width - 1]. With a zero shift the second
    // argument therefore contributes no bits, as fshl requires.
    auto down_amount = builder.CreateXor(shift_amount, mod_mask);
    auto lo_pre = builder.CreateLShr(arg_lo, 1);
    auto lo_bits = builder.CreateLShr(lo_pre, down_amount);

    return builder.CreateOr(lo_bits, hi_bits);
  });
}

// Replaces calls to llvm.memset.* with explicit zero stores.
//
// Only zero-initializing memsets are supported. Any other initializer prints
// the operand and hits llvm_unreachable. The destination (looking through a
// single BitCastInst) is filled with NumBytes / sizeof(pointee) stores of the
// pointee type's null value, stepping with a GEP of 1 element per store.
// Returns true if any call was replaced.
bool ReplaceLLVMIntrinsicsPass::replaceMemset(Module &M) {
  bool Changed = false;
  const auto &Layout = M.getDataLayout();

  for (auto &F : M) {
    if (!F.getName().startswith("llvm.memset")) {
      continue;
    }

    SmallVector<CallInst *, 8> CallsToReplace;

    for (auto U : F.users()) {
      if (auto CI = dyn_cast<CallInst>(U)) {
        auto Initializer = dyn_cast<ConstantInt>(CI->getArgOperand(1));

        // We only handle cases where the initializer is a constant int that
        // is 0.
        if (!Initializer || (0 != Initializer->getZExtValue())) {
          // Print the operand itself: Initializer is null when the value is
          // not a ConstantInt.
          CI->getArgOperand(1)->print(errs());
          llvm_unreachable("Unhandled llvm.memset.* instruction that had a "
                           "non-0 initializer!");
        }

        CallsToReplace.push_back(CI);
      }
    }

    for (auto CI : CallsToReplace) {
      auto NewArg = CI->getArgOperand(0);
      auto Bitcast = dyn_cast<BitCastInst>(NewArg);
      if (Bitcast != nullptr) {
        NewArg = Bitcast->getOperand(0);
      }

      auto NumBytes = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      auto Ty = NewArg->getType();
      auto PointeeTy = Ty->getPointerElementType();
      auto Zero = Constant::getNullValue(PointeeTy);

      const auto num_stores = NumBytes / Layout.getTypeAllocSize(PointeeTy);
      assert((NumBytes == num_stores * Layout.getTypeAllocSize(PointeeTy)) &&
             "Null memset can't be divided evenly across multiple stores.");
      assert((num_stores & 0xFFFFFFFF) == num_stores);

      // Emit exactly num_stores stores. A zero-length memset emits none.
      auto I32Ty = Type::getInt32Ty(M.getContext());
      auto One = ConstantInt::get(I32Ty, 1);
      Value *Ptr = NewArg;
      for (uint64_t i = 0; i < num_stores; ++i) {
        if (i != 0) {
          Ptr = GetElementPtrInst::Create(PointeeTy, Ptr, {One}, "", CI);
        }
        new StoreInst(Zero, Ptr, CI);
      }

      CI->eraseFromParent();
      Changed = true;

      // The same bitcast may feed several memsets (or other users). Only
      // erase it once nothing references it any more.
      if (Bitcast != nullptr && Bitcast->use_empty()) {
        Bitcast->eraseFromParent();
      }
    }
  }

  return Changed;
}

// Replaces calls to llvm.memcpy.* with calls to the clspv copy-memory helper
// (clspv::CopyMemoryFunction()).
//
// Both pointer arguments are expected to be bitcasts. The original
// (pre-bitcast) pointee types are unpacked until they agree. If no unpacking
// is needed, a single helper call copies the whole object. Otherwise one
// helper call per element is emitted over GEPs to each element. For SPIR-V
// 1.4+ the source alignment is passed as an extra argument.
// Returns true if any call was replaced.
bool ReplaceLLVMIntrinsicsPass::replaceMemcpy(Module &M) {
  bool Changed = false;
  const auto &Layout = M.getDataLayout();

  // Unpack source and destination types until we find a matching
  // element type.  Count the number of levels we unpack for the
  // source and destination types.  So far this only works for
  // array types, but could be generalized to other regular types
  // like vectors.
  auto match_types = [&Layout](CallInst &CI, uint64_t Size, Type **DstElemTy,
                               Type **SrcElemTy, unsigned *NumDstUnpackings,
                               unsigned *NumSrcUnpackings) {
    auto descend_type = [](Type *InType) {
      Type *OutType = InType;
      if (OutType->isStructTy()) {
        OutType = OutType->getStructElementType(0);
      } else if (OutType->isArrayTy()) {
        OutType = OutType->getArrayElementType();
      } else if (auto vec_type = dyn_cast<VectorType>(OutType)) {
        OutType = vec_type->getElementType();
      } else {
        assert(false && "Don't know how to descend into type");
      }

      return OutType;
    };

    while (*SrcElemTy != *DstElemTy) {
      auto SrcElemSize = Layout.getTypeSizeInBits(*SrcElemTy);
      auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy);
      if (SrcElemSize >= DstElemSize) {
        *SrcElemTy = descend_type(*SrcElemTy);
        (*NumSrcUnpackings)++;
      } else if (DstElemSize >= SrcElemSize) {
        *DstElemTy = descend_type(*DstElemTy);
        (*NumDstUnpackings)++;
      } else {
        errs() << "Don't know how to unpack types for memcpy: " << CI
               << "\ngot to: " << **DstElemTy << " vs " << **SrcElemTy << "\n";
        assert(false && "Don't know how to unpack these types");
      }
    }

    auto DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
    while (Size < DstElemSize) {
      *DstElemTy = descend_type(*DstElemTy);
      *SrcElemTy = descend_type(*SrcElemTy);
      (*NumDstUnpackings)++;
      (*NumSrcUnpackings)++;
      DstElemSize = Layout.getTypeSizeInBits(*DstElemTy) / 8;
    }
  };

  SmallPtrSet<Instruction *, 8> BitCastsToForget;
  for (auto &F : M) {
    if (F.getName().startswith("llvm.memcpy")) {
      SmallVector<CallInst *, 8> CallsToReplaceWithSpirvCopyMemory;

      for (auto U : F.users()) {
        if (auto CI = dyn_cast<CallInst>(U)) {
          assert(isa<BitCastOperator>(CI->getArgOperand(0)));
          auto Dst =
              dyn_cast<BitCastOperator>(CI->getArgOperand(0))->getOperand(0);

          assert(isa<BitCastOperator>(CI->getArgOperand(1)));
          auto Src =
              dyn_cast<BitCastOperator>(CI->getArgOperand(1))->getOperand(0);

          // The original type of Dst we get from the argument to the bitcast
          // instruction.
          auto DstTy = Dst->getType();
          assert(DstTy->isPointerTy());

          // The original type of Src we get from the argument to the bitcast
          // instruction.
          auto SrcTy = Src->getType();
          assert(SrcTy->isPointerTy());

          // Check that the size is a constant integer.
          assert(isa<ConstantInt>(CI->getArgOperand(2)));
          auto Size =
              dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();

          auto DstElemTy = DstTy->getPointerElementType();
          auto SrcElemTy = SrcTy->getPointerElementType();
          unsigned NumDstUnpackings = 0;
          unsigned NumSrcUnpackings = 0;
          match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
                      &NumSrcUnpackings);

          // Check that the pointee types match.
          assert(DstElemTy == SrcElemTy);

          auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;
          (void)DstElemSize;

          // Check that the size is a multiple of the size of the pointee type.
          assert(Size % DstElemSize == 0);

          auto Alignment = cast<MemIntrinsic>(CI)->getDestAlignment();
          auto TypeAlignment = Layout.getABITypeAlignment(DstElemTy);
          (void)Alignment;
          (void)TypeAlignment;

          // Check that the alignment is at least the alignment of the pointee
          // type.
          assert(Alignment >= TypeAlignment);

          // Check that the alignment is a multiple of the alignment of the
          // pointee type.
          assert(0 == (Alignment % TypeAlignment));

          // Check that volatile is a constant.
          assert(isa<ConstantInt>(CI->getArgOperand(3)));

          CallsToReplaceWithSpirvCopyMemory.push_back(CI);
        }
      }

      for (auto CI : CallsToReplaceWithSpirvCopyMemory) {
        auto Arg0 = dyn_cast<BitCastOperator>(CI->getArgOperand(0));
        auto Arg1 = dyn_cast<BitCastOperator>(CI->getArgOperand(1));
        auto Arg3 = dyn_cast<ConstantInt>(CI->getArgOperand(3));

        auto I32Ty = Type::getInt32Ty(M.getContext());
        auto DstAlignment =
            ConstantInt::get(I32Ty, cast<MemCpyInst>(CI)->getDestAlignment());
        auto SrcAlignment =
            ConstantInt::get(I32Ty, cast<MemCpyInst>(CI)->getSourceAlignment());
        auto Volatile = ConstantInt::get(I32Ty, Arg3->getZExtValue());

        auto Dst = Arg0->getOperand(0);
        auto Src = Arg1->getOperand(0);

        auto DstElemTy = Dst->getType()->getPointerElementType();
        auto SrcElemTy = Src->getType()->getPointerElementType();
        unsigned NumDstUnpackings = 0;
        unsigned NumSrcUnpackings = 0;
        auto Size = dyn_cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
        match_types(*CI, Size, &DstElemTy, &SrcElemTy, &NumDstUnpackings,
                    &NumSrcUnpackings);
        auto SPIRVIntrinsic = clspv::CopyMemoryFunction();

        auto DstElemSize = Layout.getTypeSizeInBits(DstElemTy) / 8;

        // All new instructions are inserted immediately before the memcpy.
        IRBuilder<> Builder(CI);

        if (NumSrcUnpackings == 0 && NumDstUnpackings == 0) {
          SmallVector<Type *, 5> param_tys = {Dst->getType(), Src->getType(),
                                              I32Ty, I32Ty};
          SmallVector<Value *, 5> param_values = {Dst, Src, DstAlignment};
          if (clspv::Option::SpvVersion() >=
              clspv::Option::SPIRVVersion::SPIRV_1_4) {
            param_tys.push_back(I32Ty);
            param_values.push_back(SrcAlignment);
          }
          param_values.push_back(Volatile);
          auto NewFType =
              FunctionType::get(F.getReturnType(), param_tys, false);
          auto NewF =
              Function::Create(NewFType, F.getLinkage(), SPIRVIntrinsic, &M);
          Builder.CreateCall(NewF, param_values, "");
        } else {
          auto Zero = ConstantInt::get(I32Ty, 0);
          SmallVector<Value *, 3> SrcIndices;
          SmallVector<Value *, 3> DstIndices;
          // Make unpacking indices.
          for (unsigned unpacking = 0; unpacking < NumSrcUnpackings;
               ++unpacking) {
            SrcIndices.push_back(Zero);
          }
          for (unsigned unpacking = 0; unpacking < NumDstUnpackings;
               ++unpacking) {
            DstIndices.push_back(Zero);
          }
          // Add a placeholder for the final index.
          SrcIndices.push_back(Zero);
          DstIndices.push_back(Zero);

          // Build the function and function type only once.
          FunctionType *NewFType = nullptr;
          Function *NewF = nullptr;

          for (unsigned i = 0; i < Size / DstElemSize; ++i) {
            auto Index = ConstantInt::get(I32Ty, i);
            SrcIndices.back() = Index;
            DstIndices.back() = Index;

            // Avoid the builder for Src in order to prevent the folder from
            // creating constant expressions for constant memcpys.
            auto SrcElemPtr =
                GetElementPtrInst::CreateInBounds(Src, SrcIndices, "", CI);
            auto DstElemPtr = Builder.CreateGEP(Dst, DstIndices);
            SmallVector<Type *, 5> param_tys = {
                DstElemPtr->getType(), SrcElemPtr->getType(), I32Ty, I32Ty};
            SmallVector<Value *, 5> param_values = {DstElemPtr, SrcElemPtr,
                                                    DstAlignment};
            if (clspv::Option::SpvVersion() >=
                clspv::Option::SPIRVVersion::SPIRV_1_4) {
              param_tys.push_back(I32Ty);
              param_values.push_back(SrcAlignment);
            }
            param_values.push_back(Volatile);
            NewFType =
                NewFType != nullptr
                    ? NewFType
                    : FunctionType::get(F.getReturnType(), param_tys, false);
            NewF = NewF != nullptr ? NewF
                                   : Function::Create(NewFType, F.getLinkage(),
                                                      SPIRVIntrinsic, &M);
            Builder.CreateCall(NewF, param_values, "");
          }
        }

        // Erase the call.
        CI->eraseFromParent();
        Changed = true;

        // Erase the bitcasts.  A particular bitcast might be used
        // in more than one memcpy, so defer actual deleting until later.
        if (isa<BitCastInst>(Arg0))
          BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg0));
        if (isa<BitCastInst>(Arg1))
          BitCastsToForget.insert(dyn_cast<BitCastInst>(Arg1));
      }
    }
  }

  // A bitcast fed to a memcpy may also have unrelated users. Erasing it while
  // still referenced would assert, so keep those.
  for (auto *Inst : BitCastsToForget) {
    if (Inst->use_empty()) {
      Inst->eraseFromParent();
    }
  }

  return Changed;
}

bool ReplaceLLVMIntrinsicsPass::removeLifetimeDeclarations(Module &M) {
  // SPIR-V OpLifetimeStart and OpLifetimeEnd require Kernel capability.
  // Vulkan doesn't support that, so remove all lifetime bounds declarations.
  //
  // Every llvm.lifetime.* declaration is removed together with its direct
  // call sites. Returns true if at least one declaration was found.
  SmallVector<Function *, 2> LifetimeFns;
  for (auto &Fn : M) {
    if (Fn.getName().startswith("llvm.lifetime."))
      LifetimeFns.push_back(&Fn);
  }

  for (auto *Fn : LifetimeFns) {
    // Snapshot the users first: erasing a call edits Fn's use list.
    SmallVector<User *, 8> Snapshot(Fn->users());
    for (auto *Usr : Snapshot) {
      if (auto *Call = dyn_cast<CallInst>(Usr))
        Call->eraseFromParent();
    }
    Fn->eraseFromParent();
  }

  return !LifetimeFns.empty();
}

// Lowers llvm.ctlz (leading == true) or llvm.cttz (leading == false) on
// integer (or integer-vector) types other than 32 bits, using the 32-bit
// version of the same intrinsic.
//  - Widths below 32 are zero-extended to 32 bits. For ctz, a guard bit at
//    position `bitwidth` caps the count. For clz, the extra leading zeros are
//    subtracted.
//  - Widths in (32, 64] are zero-extended to 64 bits if needed (with a guard
//    bit for ctz). The 64-bit value is split into two 32-bit halves whose
//    counts are combined. For clz, the extra leading zeros are subtracted.
// The zero-input case is always defined (the 32-bit calls pass
// is_zero_poison = false) and yields the bit width.
// 32-bit and wider-than-64-bit counts are left untouched.
// Returns true if any call was rewritten.
bool ReplaceLLVMIntrinsicsPass::replaceCountZeroes(Function &F, bool leading) {
  if (!isa<IntegerType>(F.getReturnType()->getScalarType()))
    return false;

  auto bitwidth = F.getReturnType()->getScalarSizeInBits();
  if (bitwidth == 32 || bitwidth > 64)
    return false;

  return replaceCallsWithValue(F, [&F, bitwidth, leading](CallInst *Call) {
    auto c_false = ConstantInt::getFalse(Call->getContext());
    auto in = Call->getArgOperand(0);
    IRBuilder<> builder(Call);
    auto ty = Call->getType()->getWithNewBitWidth(32);
    auto c32 = ConstantInt::get(ty, 32);
    auto func_32bit = Intrinsic::getDeclaration(
        F.getParent(), leading ? Intrinsic::ctlz : Intrinsic::cttz, ty);

    if (bitwidth < 32) {
      // Extend the input to 32-bits and perform a clz/ctz.
      auto zext = builder.CreateZExt(in, ty);
      Value *call_input = zext;
      if (!leading) {
        // Or the extended input value with a constant that caps the max to the
        // right bitwidth (e.g. 256 for i8 and 65536 for i16).
        auto mask = ConstantInt::get(ty, uint64_t(1) << bitwidth);
        call_input = builder.CreateOr(zext, mask);
      }
      auto call = builder.CreateCall(func_32bit->getFunctionType(), func_32bit,
                                     {call_input, c_false});
      Value *tmp = call;
      if (leading) {
        // The zero-extension added (32 - bitwidth) leading zeros; remove them.
        auto sub_const = ConstantInt::get(ty, 32 - bitwidth);
        tmp = builder.CreateSub(call, sub_const);
      }
      // Truncate the intermediate result to the right size.
      return builder.CreateTrunc(tmp, Call->getType());
    }

    // 32 < bitwidth <= 64. The half-splitting below assumes exactly 64 bits,
    // so widen narrower inputs first.
    Value *wide = in;
    if (bitwidth < 64) {
      auto wide_ty = Call->getType()->getWithNewBitWidth(64);
      wide = builder.CreateZExt(in, wide_ty);
      if (!leading) {
        // Guard bit so that ctz(0) yields bitwidth rather than 64.
        auto mask = ConstantInt::get(wide_ty, uint64_t(1) << bitwidth);
        wide = builder.CreateOr(wide, mask);
      }
    }

    // Perform a 32-bit version of clz/ctz on each half of the 64-bit value.
    auto lshr = builder.CreateLShr(wide, 32);
    auto top_bits = builder.CreateTrunc(lshr, ty);
    auto bot_bits = builder.CreateTrunc(wide, ty);
    auto top_func = builder.CreateCall(func_32bit->getFunctionType(),
                                       func_32bit, {top_bits, c_false});
    auto bot_func = builder.CreateCall(func_32bit->getFunctionType(),
                                       func_32bit, {bot_bits, c_false});
    Value *tmp = nullptr;
    if (leading) {
      // For clz, if clz(top) is 32, return 32 + clz(bot).
      auto cmp = builder.CreateICmpEQ(top_func, c32);
      auto adjust = builder.CreateAdd(bot_func, c32);
      tmp = builder.CreateSelect(cmp, adjust, top_func);
      if (bitwidth < 64) {
        // The zero-extension added (64 - bitwidth) leading zeros.
        tmp = builder.CreateSub(tmp, ConstantInt::get(ty, 64 - bitwidth));
      }
    } else {
      // For ctz, if ctz(bot) is 32, return 32 + ctz(top).
      auto bot_cmp = builder.CreateICmpEQ(bot_func, c32);
      auto adjust = builder.CreateAdd(top_func, c32);
      tmp = builder.CreateSelect(bot_cmp, adjust, bot_func);
    }
    // Extend the intermediate result to the correct size.
    return builder.CreateZExt(tmp, Call->getType());
  });
}

// Lowers llvm.copysign(X, Y) to integer bit operations: the result is X with
// its sign bit replaced by Y's sign bit. Scalars and fixed vectors are both
// handled by bitcasting to an integer type of the same shape.
// Returns true if any call was rewritten.
bool ReplaceLLVMIntrinsicsPass::replaceCopysign(Function &F) {
  return replaceCallsWithValue(F, [&F](CallInst *CI) {
    auto XValue = CI->getOperand(0);
    auto YValue = CI->getOperand(1);

    auto Ty = XValue->getType();

    // Integer type with the same element width (and lane count) as Ty.
    Type *IntTy = Type::getIntNTy(F.getContext(), Ty->getScalarSizeInBits());
    if (auto vec_ty = dyn_cast<VectorType>(Ty)) {
      IntTy = FixedVectorType::get(
          IntTy, vec_ty->getElementCount().getKnownMinValue());
    }

    // Return X with the sign of Y

    // Sign bit masks. They are built as APInts of the exact element width: a
    // plain `1 << SignBit` on int overflows (UB) for 32-bit elements and
    // produces a wrong mask for 64-bit (double) elements.
    const APInt SignBitMask = APInt::getSignMask(IntTy->getScalarSizeInBits());
    auto SignBitMaskValue = ConstantInt::get(IntTy, SignBitMask);
    auto NotSignBitMaskValue = ConstantInt::get(IntTy, ~SignBitMask);

    IRBuilder<> Builder(CI);

    // Extract sign of Y
    auto YInt = Builder.CreateBitCast(YValue, IntTy);
    auto YSign = Builder.CreateAnd(YInt, SignBitMaskValue);

    // Clear sign bit in X
    auto XInt = Builder.CreateBitCast(XValue, IntTy);
    XInt = Builder.CreateAnd(XInt, NotSignBitMaskValue);

    // Insert sign bit of Y into X
    auto NewXInt = Builder.CreateOr(XInt, YSign);

    // And cast back to floating-point
    return Builder.CreateBitCast(NewXInt, Ty);
  });
}
